{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as dsets\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用 MNIST 資料集  (載入資料)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_dataset = dsets.MNIST(\n",
    "    root = './data',\n",
    "    train = True,\n",
    "    transform = transforms.ToTensor(),\n",
    "    download = True\n",
    ")\n",
    "test_dataset = dsets.MNIST(\n",
    "    root = './data',\n",
    "    train = False,\n",
    "    transform = transforms.ToTensor()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_dataset : torch.Size([60000, 28, 28])\n"
     ]
    }
   ],
   "source": [
    "print('train_dataset :', train_dataset.train_data.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_dataset : torch.Size([60000])\n"
     ]
    }
   ],
   "source": [
    "print('train_dataset :', train_dataset.train_labels.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test_dataset : torch.Size([10000, 28, 28])\n"
     ]
    }
   ],
   "source": [
    "print('test_dataset :', test_dataset.test_data.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test_dataset : torch.Size([10000])\n"
     ]
    }
   ],
   "source": [
    "print('test_dataset :', test_dataset.test_labels.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "batch_size = 100\n",
    "n_iters = 1000\n",
    "num_epochs = n_iters / (len(train_dataset) / batch_size)\n",
    "num_epochs = int(num_epochs)\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(dataset = train_dataset,\n",
    "                                           batch_size = batch_size,\n",
    "                                           shuffle = True)\n",
    "test_loader = torch.utils.data.DataLoader(dataset = test_dataset,\n",
    "                                           batch_size = batch_size,\n",
    "                                           shuffle = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Model !\n",
    "![](./imgs/rnn.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class RNNModel(nn.Module):\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super(RNNModel, self).__init__()   \n",
    "        \n",
    "        self.rnn = nn.RNN(\n",
    "            input_size = input_dim,\n",
    "            hidden_size = 100,\n",
    "            num_layers = 1,\n",
    "            batch_first=True,\n",
    "            nonlinearity='relu')\n",
    "        \n",
    "        self.fc = nn.Linear(100, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        # (layer_dim, batch_size, hidden_dim)\n",
    "        h0 = None\n",
    "        # another hidden state example :\n",
    "        # h0 = Variable(torch.zeros(1, x.size(0), 100))\n",
    "        \n",
    "        out, hn = self.rnn(x, h0)\n",
    "        # \"out\" dim : (100, 28, 100)\n",
    "        # \"-1\" means the last time step\n",
    "        \n",
    "        out = self.fc(out[:, -1, :])  # (100, 100)\n",
    "        # \"out\" dim : (100, 10)\n",
    "        \n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "input_dim = 28\n",
    "output_dim = 10\n",
    "model = RNNModel(input_dim, output_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "cost_fun = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "lr = 0.1\n",
    "opt = torch.optim.SGD(model.parameters(), lr = lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./imgs/rnn_weights.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 5. Loss: 2.30782413482666. Acc: 8.02\n",
      "Iteration: 10. Loss: 2.300601005554199. Acc: 9.71\n",
      "Iteration: 15. Loss: 2.3077762126922607. Acc: 11.56\n",
      "Iteration: 20. Loss: 2.301717758178711. Acc: 9.84\n",
      "Iteration: 25. Loss: 2.304553270339966. Acc: 10.5\n",
      "Iteration: 30. Loss: 2.303130626678467. Acc: 8.98\n",
      "Iteration: 35. Loss: 2.3052358627319336. Acc: 9.28\n",
      "Iteration: 40. Loss: 2.2994351387023926. Acc: 9.72\n",
      "Iteration: 45. Loss: 2.302813768386841. Acc: 12.43\n",
      "Iteration: 50. Loss: 2.291151762008667. Acc: 11.35\n",
      "Iteration: 55. Loss: 2.296333074569702. Acc: 11.71\n",
      "Iteration: 60. Loss: 2.296353816986084. Acc: 11.35\n",
      "Iteration: 65. Loss: 2.291320323944092. Acc: 11.35\n",
      "Iteration: 70. Loss: 2.301290512084961. Acc: 11.35\n",
      "Iteration: 75. Loss: 2.298933744430542. Acc: 12.33\n",
      "Iteration: 80. Loss: 2.2862627506256104. Acc: 11.37\n",
      "Iteration: 85. Loss: 2.2919387817382812. Acc: 11.36\n",
      "Iteration: 90. Loss: 2.281371593475342. Acc: 11.4\n",
      "Iteration: 95. Loss: 2.295684337615967. Acc: 11.4\n",
      "Iteration: 100. Loss: 2.2877917289733887. Acc: 14.46\n",
      "Iteration: 105. Loss: 2.2995691299438477. Acc: 12.19\n",
      "Iteration: 110. Loss: 2.2970426082611084. Acc: 12.76\n",
      "Iteration: 115. Loss: 2.295147657394409. Acc: 13.46\n",
      "Iteration: 120. Loss: 2.294701337814331. Acc: 14.29\n",
      "Iteration: 125. Loss: 2.303502321243286. Acc: 17.52\n",
      "Iteration: 130. Loss: 2.2842724323272705. Acc: 19.08\n",
      "Iteration: 135. Loss: 2.2786972522735596. Acc: 18.83\n",
      "Iteration: 140. Loss: 2.2923078536987305. Acc: 18.55\n",
      "Iteration: 145. Loss: 2.2706942558288574. Acc: 19.87\n",
      "Iteration: 150. Loss: 2.2823739051818848. Acc: 19.87\n",
      "Iteration: 155. Loss: 2.2596092224121094. Acc: 19.84\n",
      "Iteration: 160. Loss: 2.291473627090454. Acc: 19.18\n",
      "Iteration: 165. Loss: 2.2532148361206055. Acc: 19.75\n",
      "Iteration: 170. Loss: 2.2626466751098633. Acc: 19.71\n",
      "Iteration: 175. Loss: 2.2123641967773438. Acc: 21.1\n",
      "Iteration: 180. Loss: 2.248894453048706. Acc: 22.99\n",
      "Iteration: 185. Loss: 2.2449729442596436. Acc: 22.78\n",
      "Iteration: 190. Loss: 2.2134130001068115. Acc: 23.59\n",
      "Iteration: 195. Loss: 2.2094531059265137. Acc: 23.94\n",
      "Iteration: 200. Loss: 2.178576946258545. Acc: 24.47\n",
      "Iteration: 205. Loss: 2.159407138824463. Acc: 20.27\n",
      "Iteration: 210. Loss: 2.0760600566864014. Acc: 25.04\n",
      "Iteration: 215. Loss: 2.1000783443450928. Acc: 28.37\n",
      "Iteration: 220. Loss: 2.0461995601654053. Acc: 18.68\n",
      "Iteration: 225. Loss: 2.0294673442840576. Acc: 24.68\n",
      "Iteration: 230. Loss: 2.125579357147217. Acc: 28.25\n",
      "Iteration: 235. Loss: 1.9273933172225952. Acc: 32.0\n",
      "Iteration: 240. Loss: 1.8278754949569702. Acc: 23.77\n",
      "Iteration: 245. Loss: 2.373661756515503. Acc: 4.98\n",
      "Iteration: 250. Loss: 2.177743911743164. Acc: 21.2\n",
      "Iteration: 255. Loss: 2.2501838207244873. Acc: 11.63\n",
      "Iteration: 260. Loss: 2.0116240978240967. Acc: 24.53\n",
      "Iteration: 265. Loss: 2.027683973312378. Acc: 31.52\n",
      "Iteration: 270. Loss: 2.28973650932312. Acc: 32.38\n",
      "Iteration: 275. Loss: 1.9849239587783813. Acc: 28.78\n",
      "Iteration: 280. Loss: 1.9157863855361938. Acc: 32.29\n",
      "Iteration: 285. Loss: 2.2078731060028076. Acc: 15.88\n",
      "Iteration: 290. Loss: 2.009370803833008. Acc: 30.54\n",
      "Iteration: 295. Loss: 1.9607937335968018. Acc: 27.74\n",
      "Iteration: 300. Loss: 1.7221893072128296. Acc: 26.08\n",
      "Iteration: 305. Loss: 1.6460521221160889. Acc: 41.63\n",
      "Iteration: 310. Loss: 1.8307888507843018. Acc: 40.74\n",
      "Iteration: 315. Loss: 2.4728074073791504. Acc: 26.53\n",
      "Iteration: 320. Loss: 1.8371131420135498. Acc: 40.01\n",
      "Iteration: 325. Loss: 1.7560515403747559. Acc: 37.52\n",
      "Iteration: 330. Loss: 1.459166407585144. Acc: 50.35\n",
      "Iteration: 335. Loss: 2.288578510284424. Acc: 14.67\n",
      "Iteration: 340. Loss: 2.1267740726470947. Acc: 30.35\n",
      "Iteration: 345. Loss: 1.9202020168304443. Acc: 35.86\n",
      "Iteration: 350. Loss: 1.767936110496521. Acc: 37.17\n",
      "Iteration: 355. Loss: 1.549329400062561. Acc: 38.28\n",
      "Iteration: 360. Loss: 1.7116602659225464. Acc: 44.09\n",
      "Iteration: 365. Loss: 1.6229132413864136. Acc: 40.09\n",
      "Iteration: 370. Loss: 1.6109799146652222. Acc: 48.05\n",
      "Iteration: 375. Loss: 1.426284909248352. Acc: 42.53\n",
      "Iteration: 380. Loss: 1.6249094009399414. Acc: 44.6\n",
      "Iteration: 385. Loss: 1.2747416496276855. Acc: 49.23\n",
      "Iteration: 390. Loss: 1.6980466842651367. Acc: 47.19\n",
      "Iteration: 395. Loss: 1.5829755067825317. Acc: 49.92\n",
      "Iteration: 400. Loss: 1.3933273553848267. Acc: 48.06\n",
      "Iteration: 405. Loss: 1.5726191997528076. Acc: 48.89\n",
      "Iteration: 410. Loss: 1.333174228668213. Acc: 51.36\n",
      "Iteration: 415. Loss: 1.277411937713623. Acc: 49.81\n",
      "Iteration: 420. Loss: 1.623317003250122. Acc: 50.01\n",
      "Iteration: 425. Loss: 1.419437050819397. Acc: 44.04\n",
      "Iteration: 430. Loss: 1.3354264497756958. Acc: 52.27\n",
      "Iteration: 435. Loss: 1.2488932609558105. Acc: 48.49\n",
      "Iteration: 440. Loss: 1.4373506307601929. Acc: 39.44\n",
      "Iteration: 445. Loss: 1.6383869647979736. Acc: 36.46\n",
      "Iteration: 450. Loss: 1.8007413148880005. Acc: 43.34\n",
      "Iteration: 455. Loss: 1.390915036201477. Acc: 48.97\n",
      "Iteration: 460. Loss: 1.4800230264663696. Acc: 51.43\n",
      "Iteration: 465. Loss: 1.4787124395370483. Acc: 51.93\n",
      "Iteration: 470. Loss: 3.796271324157715. Acc: 25.97\n",
      "Iteration: 475. Loss: 1.71541166305542. Acc: 43.23\n",
      "Iteration: 480. Loss: 1.5629210472106934. Acc: 44.07\n",
      "Iteration: 485. Loss: 1.4443551301956177. Acc: 44.69\n",
      "Iteration: 490. Loss: 1.5549191236495972. Acc: 50.7\n",
      "Iteration: 495. Loss: 1.4658691883087158. Acc: 49.97\n",
      "Iteration: 500. Loss: 1.7749496698379517. Acc: 34.28\n",
      "Iteration: 505. Loss: 1.573665738105774. Acc: 52.29\n",
      "Iteration: 510. Loss: 1.354344367980957. Acc: 47.43\n",
      "Iteration: 515. Loss: 1.4813590049743652. Acc: 53.13\n",
      "Iteration: 520. Loss: 1.3405498266220093. Acc: 42.66\n",
      "Iteration: 525. Loss: 1.3963165283203125. Acc: 49.75\n",
      "Iteration: 530. Loss: 1.217664122581482. Acc: 54.03\n",
      "Iteration: 535. Loss: 1.3936368227005005. Acc: 48.0\n",
      "Iteration: 540. Loss: 1.6639338731765747. Acc: 37.92\n",
      "Iteration: 545. Loss: 1.3928662538528442. Acc: 53.88\n",
      "Iteration: 550. Loss: 1.2078827619552612. Acc: 44.64\n",
      "Iteration: 555. Loss: 1.6510934829711914. Acc: 44.75\n",
      "Iteration: 560. Loss: 1.2386350631713867. Acc: 52.9\n",
      "Iteration: 565. Loss: 1.3392354249954224. Acc: 47.78\n",
      "Iteration: 570. Loss: 1.360122799873352. Acc: 44.4\n",
      "Iteration: 575. Loss: 1.3723400831222534. Acc: 55.03\n",
      "Iteration: 580. Loss: 1.1518886089324951. Acc: 52.02\n",
      "Iteration: 585. Loss: 1.242003321647644. Acc: 51.5\n",
      "Iteration: 590. Loss: 1.5706968307495117. Acc: 52.95\n",
      "Iteration: 595. Loss: 1.424400806427002. Acc: 53.52\n",
      "Iteration: 600. Loss: 1.2034755945205688. Acc: 56.05\n"
     ]
    }
   ],
   "source": [
    "iters = 0\n",
    "for epoch in range(num_epochs):\n",
    "    for i, (images, labels) in enumerate(train_loader):\n",
    "        # (100, 1, 28, 28) > (100, 28, 28)\n",
    "        images = Variable(images.view(-1, 28, 28))\n",
    "        labels = Variable(labels)\n",
    "        \n",
    "        opt.zero_grad()\n",
    "        \n",
    "        outputs = model(images)\n",
    "        \n",
    "        loss = cost_fun(outputs, labels)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        \n",
    "        iters += 1\n",
    "        \n",
    "        if iters % 5 == 0:\n",
    "            correct = 0\n",
    "            total = 0\n",
    "            for images, labels in test_loader:\n",
    "                images = Variable(images.view(-1, 28, 28))\n",
    "                outputs = model(images)\n",
    "                _, predicted = torch.max(outputs.data, 1)\n",
    "                total += labels.size(0)\n",
    "                \n",
    "                correct += (predicted == labels).sum()\n",
    "            acc = 100 * correct / total\n",
    "            print('Iteration: {}. Loss: {}. Acc: {}'.format(iters, loss.data[0], acc))\n",
    "                "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Rerference :     \n",
    "    - 莫煩 pytorch : https://morvanzhou.github.io/tutorials/machine-learning/torch/\n",
    "    \n",
    "    - 官方 tutorail : http://pytorch.org/tutorials/\n",
    "    \n",
    "    - https://blog.csdn.net/LiuPeiP_VIPL/article/details/75000742\n",
    "    \n",
    "    - https://www.youtube.com/watch?v=xCGidAeyS4M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
